Program Listing for File lanenet_node.py
↰ Return to documentation for file (codes/lanekerbnetros/lanenet_node.py)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# @Time : 17-05-2019
# @Author : Zhou Hui
# @Original site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
# @File : lanenet_node.py
"""
A ros node for lane and kerb detection
"""
import time
import math
import tensorflow as tf
import numpy as np
import cv2
import os
from lanenet_model import lanenet_merge_model
from lanenet_model import lanenet_cluster
from lanenet_model import lanenet_postprocess
from config import global_config
import rospy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge, CvBridgeError
# Per-channel means subtracted from every resized frame before inference.
# Presumably the standard VGG/ImageNet means in B, G, R order, matching the
# "bgr8" frames requested from cv_bridge -- confirm against the training code.
VGG_MEAN = [103.939, 116.779, 123.68]
# Shared LaneNet configuration; its TEST/TRAIN GPU options are read when the
# TensorFlow session is configured.
CFG = global_config.cfg
class lanenet_detector(object):
    """ROS node that runs LaneNet lane and curb detection on an image topic.

    Private ROS parameters read in ``__init__``:
        ~image_topic:  topic the input ``sensor_msgs/Image`` frames arrive on.
        ~output_image: topic the detection image is published on.
        ~output_lane:  stored as ``self.output_lane``; not used in this class.
        ~weight_path:  checkpoint restored into the TensorFlow session.
        ~use_gpu:      truthy to let TensorFlow use one GPU, falsy for CPU only.
        ~save_dir:     directory frames are written to while saving is on.

    Every frame is resized to INPUT_WIDTH x INPUT_HEIGHT, run through the
    network and published. An OpenCV window shows the input and the detection
    side by side. In that window, 's' toggles saving, space toggles pause, and
    ESC shuts the node down.
    """

    # Network input size. cv2.resize takes (width, height), while the input
    # placeholder and numpy arrays use (height, width), so both values are
    # kept here in one place.
    INPUT_WIDTH = 512
    INPUT_HEIGHT = 256

    def __init__(self):
        """Read parameters, build the network and wire up ROS I/O."""
        self.image_topic = rospy.get_param('~image_topic')
        self.output_image = rospy.get_param('~output_image')
        self.output_lane = rospy.get_param('~output_lane')
        self.weight_path = rospy.get_param('~weight_path')
        self.use_gpu = rospy.get_param('~use_gpu')
        self.save_dir = rospy.get_param('~save_dir')
        self.framecount = 0
        # Preview every frame_show_ratio-th frame. While saving is on, save
        # every frame_save_ratio-th frame (saving only happens on previewed frames).
        self.frame_show_ratio = 1
        self.frame_save_ratio = 5
        # Side-by-side preview canvas: input on the left, detection on the right.
        self.frame_2 = np.zeros((self.INPUT_HEIGHT, self.INPUT_WIDTH * 2, 3), np.uint8)
        # Delay handed to cv2.waitKey: 1 keeps streaming, 0 blocks until a key (paused).
        self.nWaitTime = 1
        # 1 while saving is toggled on with 's', otherwise 0.
        self.s_nWaitTime = 0
        self.savecount = 0
        self.init_lanenet()
        self.bridge = CvBridge()
        # Create the publisher before subscribing: img_callback publishes and
        # may run as soon as the subscriber exists.
        self.pub_image = rospy.Publisher(self.output_image, Image, queue_size=1)
        self.sub_image = rospy.Subscriber(self.image_topic, Image, self.img_callback, queue_size=1)

    def init_lanenet(self):
        """Build the LaneNet inference graph and restore its weights.

        Sets ``input_tensor``, ``binary_seg_ret``, ``instance_seg_ret``,
        ``cluster``, ``postprocessor`` and ``sess`` on the instance.
        """
        self.input_tensor = tf.placeholder(
            dtype=tf.float32,
            shape=[1, self.INPUT_HEIGHT, self.INPUT_WIDTH, 3],
            name='input_tensor')
        phase_tensor = tf.constant('test', tf.string)
        net = lanenet_merge_model.LaneNet(phase=phase_tensor, net_flag='vgg')
        self.binary_seg_ret, self.instance_seg_ret = net.inference(
            input_tensor=self.input_tensor, name='lanenet_model')
        self.cluster = lanenet_cluster.LaneNetCluster()
        self.postprocessor = lanenet_postprocess.LaneNetPoseProcessor()
        saver = tf.train.Saver()

        # device_count caps how many devices of each type the session may use.
        # Capping 'CPU' at 0 would not keep work off the GPU; capping 'GPU'
        # at 0 is what confines the session to the CPU.
        if self.use_gpu:
            sess_config = tf.ConfigProto(device_count={'GPU': 1})
        else:
            sess_config = tf.ConfigProto(device_count={'GPU': 0})
        sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
        sess_config.gpu_options.allocator_type = 'BFC'
        self.sess = tf.Session(config=sess_config)
        saver.restore(sess=self.sess, save_path=self.weight_path)

    def img_callback(self, data):
        """Process one ``sensor_msgs/Image``: detect, publish, preview and save.

        Frames that cv_bridge cannot convert to bgr8 are logged and skipped.
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            # Without a decoded frame there is nothing to process; continuing
            # would use an undefined cv_image.
            rospy.logerr('cv_bridge conversion failed: %s', e)
            return

        size = (self.INPUT_WIDTH, self.INPUT_HEIGHT)
        src_image = cv2.resize(cv_image, size, interpolation=cv2.INTER_LINEAR)
        # Separate copy for preview and saving, in case the mask step draws
        # into src_image (which it receives as its source image).
        copy_image = src_image.copy()
        mask_image = self.inference_net(src_image - VGG_MEAN, src_image)
        self.pub_image.publish(self.bridge.cv2_to_imgmsg(mask_image, "bgr8"))

        height, width = src_image.shape[:2]
        self.frame_2[:height, :width, :] = copy_image
        self.frame_2[:height, width:width * 2, :] = mask_image

        if self.framecount % self.frame_show_ratio == 0:
            cv2.imshow("left: input image, right: lane and curb detction; Press 's' to save, ' ' to pause, ESC' to quit", self.frame_2)
            key = cv2.waitKey(self.nWaitTime)
            if self.s_nWaitTime and self.framecount % self.frame_save_ratio == 0:
                self._save_frames(copy_image, mask_image)
            self._handle_key(key)
        print('framecount: {:0>6d}'.format(self.framecount))
        self.framecount += 1

    def _save_frames(self, input_image, mask_image):
        """Write the input, the detection and the side-by-side preview for this frame."""
        # cv2.imwrite signals a failed write through its return value, so make
        # sure the directory exists and warn on any unsuccessful write.
        if not os.path.isdir(self.save_dir):
            os.makedirs(self.save_dir)
        outputs = (("image_%06i.jpg", input_image),
                   ("curb_%06i.jpg", mask_image),
                   ("result_%06i.jpg", self.frame_2))
        for pattern, picture in outputs:
            path = os.path.join(self.save_dir, pattern % self.savecount)
            if not cv2.imwrite(path, picture):
                rospy.logwarn('failed to write %s', path)
        print('savecount: {:0>6d}'.format(self.savecount))
        self.savecount += 1

    def _handle_key(self, key):
        """Apply the preview-window keyboard command for a ``cv2.waitKey`` result."""
        # Some HighGUI backends can set modifier bits above the low byte, so
        # compare on the key code only. "No key" (-1) maps to 255, which
        # matches no command.
        key &= 0xFF
        if key == 27:  # ESC
            if not rospy.is_shutdown():
                rospy.signal_shutdown('Quit')
        elif key == ord('s'):
            self.s_nWaitTime = 0 if self.s_nWaitTime else 1
            if self.s_nWaitTime:
                print('Start saving images..., frame_save_ratio = {}'.format(self.frame_save_ratio))
            else:
                print('End...')
        elif key == 32:  # space bar
            self.nWaitTime = 0 if self.nWaitTime else 1
            if self.nWaitTime:
                print('Unpaused...')
            else:
                print('Paused...')

    def preprocessing(self, img):
        """Resize a frame to the network input size and subtract VGG_MEAN.

        Not called within this class; img_callback performs the same two
        steps inline so that it can reuse the resized 8-bit frame.
        """
        image = cv2.resize(img, (self.INPUT_WIDTH, self.INPUT_HEIGHT), interpolation=cv2.INTER_LINEAR)
        return image - VGG_MEAN

    def inference_net(self, img, original_img):
        """Run LaneNet on one preprocessed frame and render the lane and curb masks.

        :param img: mean-subtracted frame shaped like the input placeholder
            without its batch axis; the batch axis is added here.
        :param original_img: resized 8-bit frame passed to ``get_lane_mask``
            as its source image.
        :return: the result of ``get_curb_mask``, which receives the
            lane-mask image as its source; img_callback publishes it as bgr8.
        """
        binary_seg_image, instance_seg_image = self.sess.run(
            [self.binary_seg_ret, self.instance_seg_ret],
            feed_dict={self.input_tensor: [img]})
        # Post-process the binary segmentation of the single batch item before clustering.
        binary_seg_image[0] = self.postprocessor.postprocess(binary_seg_image[0])
        mask_image = self.cluster.get_lane_mask(binary_seg_ret=binary_seg_image[0],
                                                instance_seg_ret=instance_seg_image[0],
                                                source_image=original_img)
        return self.cluster.get_curb_mask(binary_seg_ret=binary_seg_image[0],
                                          instance_seg_ret=instance_seg_image[0],
                                          source_image=mask_image)
if __name__ == '__main__':
    # Register the node first: the detector reads private '~' parameters and
    # creates its publisher/subscriber in its constructor.
    rospy.init_node('lanenet_node')
    detector = lanenet_detector()  # kept referenced for the node's lifetime
    # Block here; frames are handled in img_callback until shutdown.
    rospy.spin()